{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.1\n",
      "IPython 6.0.0\n",
      "\n",
      "tensorflow 1.2.0\n",
      "numpy 1.13.0\n",
      "h5py 2.7.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p tensorflow,numpy,h5py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Storing an Image Dataset for Minibatch Training using HDF5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook provides an example for how to save a large dataset of images as Hierarchical Data Format (HDF) for quick access during minibatch learning. This approach uses the common [HDF5](https://support.hdfgroup.org/HDF5/) format and should be accessible to any programming language or tool with an HDF5 API.\n",
    "\n",
    "While this approach performs reasonably well (sufficiently well for my applications), you may also be interested in TensorFlow's \"[Reading Data](https://www.tensorflow.org/programmers_guide/reading_data)\" guide to work with `TfRecords` and file queues."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. The Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's pretend we have a directory of images containing two subdirectories with images for training, validation, and testing. The following function will create such a dataset of images in JPEG format locally for demonstration purposes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./train-images-idx3-ubyte.gz\n",
      "Extracting ./train-labels-idx1-ubyte.gz\n",
      "Extracting ./t10k-images-idx3-ubyte.gz\n",
      "Extracting ./t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# Note that executing the following code \n",
    "# cell will download the MNIST dataset\n",
    "# and save all the 60,000 images as separate JPEG\n",
    "# files. This might take a few minutes depending\n",
    "# on your machine.\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# load utilities from ../helper.py\n",
    "import sys\n",
    "sys.path.insert(0, '..') \n",
    "from helper import mnist_export_to_jpg\n",
    "\n",
    "np.random.seed(123)\n",
    "mnist_export_to_jpg(path='./')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mnist_train subdirectories ['.DS_Store', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']\n",
      "mnist_valid subdirectories ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']\n",
      "mnist_test subdirectories ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "for i in ('train', 'valid', 'test'):\n",
    "    print('mnist_%s subdirectories' % i, os.listdir('mnist_%s' % i))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the names of the subdirectories correspond directly to the class label of the images that are stored under it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make sure that the images look okay, the snippet below plots an example image from the subdirectory `mnist_train/9/`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(28, 28)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEkpJREFUeJzt3W2MVGWWB/D/4cUGupGX7V5AB22MZlQwMkmFbMSY0XFG\nxkyCYxSHmJE1ZtoP7GQnmZhVNsbXGLPuQPywkjALGdjMOrMRVDS6K5AJZgKZ0BBWcFyVxcYBge4G\nku7mRWj67Ie+uK32PaetW3VvFef/SwjddepWPV1d/67qPvd5HlFVEFE8o4oeABEVg+EnCorhJwqK\n4ScKiuEnCorhJwqK4ScKiuEnCorhJwpqTJ531tzcrK2trXneJVFVeGfGikhOI/myjo4OdHd3j+jO\nM4VfRBYAeBHAaAD/qqrPW9dvbW1Fe3t7lrusS+fPnzfr3hNl1KjqvUEr8klc7fu2bj/rbff395v1\nMWNyfV39QqlUGvF1y35WichoAP8C4IcArgewWESuL/f2iChfWV5S5gHYp6r7VfUsgN8BWFiZYRFR\ntWUJ/+UA/jLk84PJZV8iIm0i0i4i7V1dXRnujogqqep/7VfVVapaUtVSS0tLte+OiEYoS/gPAZg5\n5PNvJZcRUR3IEv4dAK4RkVkicgmAnwDYWJlhEVG1ld2PUNV+Efk7AP+FwVbfGlV9v2Iju4iMHj06\n0/Feq3BgYCC1NnbsWPPYrC2vc+fOmXXr/k+ePGke642tsbGx7OO9x9RrrxbVyqukTF+Bqr4F4K0K\njYWIcsTTe4mCYviJgmL4iYJi+ImCYviJgmL4iYKq/2ZlHejr6zPrTU1NZt07T8Cqnzp1KtNtNzQ0\nmHXvPALrHATv686qt7c3teZ9XVnPzagHfOUnCorhJwqK4ScKiuEnCorhJwqK4ScKiq2+HGRtaWVZ\n5XbChAmZ7tvjTY21xt7T01P2sQAwZcoUsz5x4kSzbjlz5oxZ91qBXgu0FvCVnygohp8oKIafKCiG\nnygohp8oKIafKCiGnygo9vlz4PXCvR1fvZ5ylmWkT58+bda9Xrs3NdYa2+TJk81jPWfPnjXrVq/d\nWxZ83LhxZY2pnvCVnygohp8oKIafKCiGnygohp8oKIafKCiGnyioTH1+EekA0AvgPIB+VS1VYlAX\nm1peBtrr03tbVXusefHV7qVn2X7cO//BG3vWrc/zUImTfG5V1e4K3A4R5Yhv+4mCyhp+BbBZRHaK\nSFslBkRE+cj6tv9mVT0kIn8NYJOI/I+qvjv0CskPhTYAuOKKKzLeHRFVSqZXflU9lPzfCeBVAPOG\nuc4qVS2paqmlpSXL3RFRBZUdfhFpFJGJFz4G8AMAeys1MCKqrixv+6cBeDVpaYwB8O+q+p8VGRUR\nVV3Z4VfV/QBurOBYLlrnzp0z69Vc4/3AgQNmffPmzWZ99erVZn379u1m3Vo739pCeyQeeeQRs/7U\nU0+l1saPH28em2WvhHrBVh9RUAw/UVAMP1FQDD9RUAw/UVAMP1FQXLo7B1mW1gaA48ePm/WXXnop\ntbZ27Vrz2P3795t1rw3Z2Nho1q12nnfGZ1dXl1l/4YUXzLq1JPrjjz9uHutt/+0tx17L07gv4Cs/\nUVAMP1FQDD9RUAw/UVAMP1FQDD9RUAw/UVC59vlV1Zze6vWUrd6q11c9deqUWZ8wYYJZt2SdsvvZ\nZ5+Z9dtuu82sf/jhh2bd4m2Tfeutt5r1+fPnm/XZs2eXfaw33fjuu+826xs2bEitPfPMM+ax3nRj\na6pyveArP1FQDD9RUAw/UVAMP1FQDD9RUAw/UVAMP1FQufb5RcTseVvzrwF7XvzZs2fNY7NuB33y\n5MnUmjenfceOHWb9jjvuMOve12Z57LHHzPqyZcvMelNTk1n3xmbVvdv2eI+7de7HsWPHzGO9reWK\nXI69UvjKTxQUw08UFMNPFBTDTxQUw0
8UFMNPFBTDTxSU2+cXkTUAfgSgU1XnJJdNBfB7AK0AOgAs\nUtUTWQdz5swZs271hS+55JJM9+31q62e8u7du81jb7nlFrPures/apT9M3rTpk2ptdtvv908tru7\n26x7vXjvcc/yfdm6datZP336tFm/7rrrUmveuvyeeujje0byyv8bAAu+ctmjALao6jUAtiSfE1Ed\nccOvqu8C+OqWMQsBXNgKZi2Auyo8LiKqsnJ/55+mqoeTj48AmFah8RBRTjL/wU9VFYCm1UWkTUTa\nRaTd23uNiPJTbviPisgMAEj+70y7oqquUtWSqpa8jRmJKD/lhn8jgCXJx0sAvF6Z4RBRXtzwi8jL\nALYD+LaIHBSRhwA8D+D7IvIxgNuTz4mojrh9flVdnFL6XoXH4vazBwYGyj7Wm3/t9dqt459++mnz\nWG9PAe++vX73DTfcYNYtzc3NZr2np8esX3rppWXf97Zt28z6+vXrzbr1fACAGTNmpNa8dfe98z6y\nnldSC3iGH1FQDD9RUAw/UVAMP1FQDD9RUAw/UVC5Lt0NDG7TncZbXts6NiuvVfjaa6+l1t544w3z\nWG/654MPPmjWr732WrMuIqm1LMuhA/7W5d735MSJ9Jnezz77rHnswYMHzbo33dg6o9Qbt9fK6+vr\nM+tZlyXPA1/5iYJi+ImCYviJgmL4iYJi+ImCYviJgmL4iYLKvc9v8Xrtlt7eXrPuTeH0vPLKK6k1\nr1c+efJks758+fKyxjQSnZ2piywBAC677DKz7n1t3nkETzzxRGrt7bffNo9taGgw695U6fvuuy+1\n5k3Z9e7bq9cDvvITBcXwEwXF8BMFxfATBcXwEwXF8BMFxfATBVVT8/mteekebxnnrKZPn55aGz9+\nvHmsN3f8k08+MeuzZs0y6xavj+9ti+6tsfDAAw+YdW+tA4v3fPAe1xtvvLHs+/a2//a+5/WAr/xE\nQTH8REEx/ERBMfxEQTH8REEx/ERBMfxEQbl9fhFZA+BHADpVdU5y2ZMAfgagK7naMlV9y7stVc20\nzbY1B3vSpEne3Zu8ddjvvffe1NqKFSsy3XepVDLr8+fPN+v33HNPam327NnmsevWrTPrK1euNOve\n1udz5sxJrXlr2+/Zs8ese+skWGPz9lLw1gq4GIzklf83ABYMc/kKVZ2b/HODT0S1xQ2/qr4L4HgO\nYyGiHGX5nf/nIvKeiKwRkSkVGxER5aLc8K8EcBWAuQAOA/hV2hVFpE1E2kWkvbu7u8y7I6JKKyv8\nqnpUVc+r6gCAXwOYZ1x3laqWVLXU3Nxc7jiJqMLKCr+IzBjy6Y8B7K3McIgoLyNp9b0M4LsAmkXk\nIIAnAHxXROYCUAAdAB6u4hiJqArc8Kvq4mEuXl2FsbisPdO9frO3/rzXc7766qtTa9u3bzePXbp0\nqVnftWuXWffmxL/zzjupNW/Ou7d+vTdv/eGH7Z/7zz33XFk1ANi3b59Z99Y5yLI+hPVcA4Dz58+b\n9Xo4T4Bn+BEFxfATBcXwEwXF8BMFxfATBcXwEwVVU1t0V1OWtg8AtLS0pNYaGxvNY3fu3GnW33zz\nTbPe0dFh1q1W36effmoee//995t1ayozALS2tpr1EydOpNa2bdtmHuudDn7TTTeZda+9a/Fax96U\n4HrAV36ioBh+oqAYfqKgGH6ioBh+oqAYfqKgGH6ioHLt84tIpt6rxZtC2d/fb9azjGvChAllHwsA\nCxYMtzjy//PG1tbWllrzpvQ2NDSYdY93+/v370+teUtzexYtWlT2sd64vfrFgK/8REEx/ERBMfxE\nQTH8REEx/ERBMfxEQTH8REFdNPP5ve29vfn83hLW1vFZ53Z7y0B7Y/eWmc7iyJEjZn369Olm3Vp2\n/OTJk+ax1vbeADB37lyzbvXqvT6+95haW80D/vOxFtT+CImoKhh+oqAYfqKgGH6ioBh+oqAYfqKg\nGH
6ioNw+v4jMBLAOwDQACmCVqr4oIlMB/B5AK4AOAItUNX2R9hHw+t1W79TrhWftlVdzfvepU6fM\n+pQpU8q+7WPHjpn1qVOnmnWvj+/16rdu3Zpa83rlV155pVn31miw1njwjvWeD/XQx/eM5CvoB/BL\nVb0ewN8AWCoi1wN4FMAWVb0GwJbkcyKqE274VfWwqu5KPu4F8AGAywEsBLA2udpaAHdVa5BEVHnf\n6L2LiLQC+A6APwGYpqqHk9IRDP5aQER1YsThF5EmAOsB/EJVe4bWdPAX4mF/KRaRNhFpF5H2rq6u\nTIMlosoZUfhFZCwGg/9bVd2QXHxURGYk9RkAOoc7VlVXqWpJVUvWZpdElC83/DL4Z/LVAD5Q1eVD\nShsBLEk+XgLg9coPj4iqZSRTeucD+CmAPSKyO7lsGYDnAfyHiDwE4ACA8tdRTnjtE6sV6B2btTWT\ndYtvS5ZWnmf8+PFmPevX5W1Pbn3Pxo0bZx7rtdtOnz5t1idOnJha87bgruY06Vrhhl9V/wgg7Rny\nvcoOh4jyUv9nKhBRWRh+oqAYfqKgGH6ioBh+oqAYfqKgamrpbq/nnGWKpjdd2Nvi2+JNTfXqXk/5\nzJkzZt0au7d9uNfv9qbsTpo0yaxbS6J73zNvGrV3DoPF+3573zNvbFmeT3nhKz9RUAw/UVAMP1FQ\nDD9RUAw/UVAMP1FQDD9RULn2+QcGBsw52FnmnmfdJjsLb60Aryf8+eefm3Vv3rvF23rce9wmT55s\n1g8cOGDWP/roo9Sad/6Cdx7AmDH209d63LM8poA/tnrAV36ioBh+oqAYfqKgGH6ioBh+oqAYfqKg\nGH6ioHLt848aNSrTHOx6lXV78Cz97qzrz/f19Zn1hoYGs27N9z9xwt7R3dsTwGM97lnPIfDq9YCv\n/ERBMfxEQTH8REEx/ERBMfxEQTH8REEx/ERBuc1KEZkJYB2AaQAUwCpVfVFEngTwMwBdyVWXqepb\n1RroxczbU8BbD8BaY95ba8DjnZfhrU9v9eqnT59uHjtz5kyznmUvhqx9em+dhKznV+RhJI9AP4Bf\nquouEZkIYKeIbEpqK1T1n6s3PCKqFjf8qnoYwOHk414R+QDA5dUeGBFV1zd6TygirQC+A+BPyUU/\nF5H3RGSNiExJOaZNRNpFpL2rq2u4qxBRAUYcfhFpArAewC9UtQfASgBXAZiLwXcGvxruOFVdpaol\nVS21tLRUYMhEVAkjCr+IjMVg8H+rqhsAQFWPqup5VR0A8GsA86o3TCKqNDf8Mjg1ajWAD1R1+ZDL\nZwy52o8B7K388IioWkby1/75AH4KYI+I7E4uWwZgsYjMxWD7rwPAw1UZYQBeOy7rdtFZeK08rxW4\nd2/6a0JPT495rLe8tjc2a0l0r9Xn3XbWFmotGMlf+/8IYLiJ0ezpE9Wx+v/xRURlYfiJgmL4iYJi\n+ImCYviJgmL4iYKq//WHLwJezzjLFuDeOQLesuJePYumpiaz7p2/4E3pterekuMeLt1NRHWL4ScK\niuEnCorhJwqK4ScKiuEnCorhJwpKqjkX/Gt3JtIF4MCQi5oBdOc2gG+mVsdWq+MCOLZyVXJsV6rq\niNbLyzX8X7tzkXZVLRU2AEOtjq1WxwVwbOUqamx8208UFMNPFFTR4V9V8P1banVstTougGMrVyFj\nK/R3fiIqTtGv/ERUkELCLyILRORDEdknIo8WMYY0ItIhIntEZLeItBc8ljUi0ikie4dcNlVENonI\nx8n/w26TVtDYnhSRQ8ljt1tE7ixobDNF5A8i8mcReV9E/j65vNDHzhhXIY9b7m/7RWQ0gI8AfB/A\nQQA7ACxW1T/nOpAUItIBoKSqhfeEReQWAH0A1qnqnOSyfwJwXFWfT35wTlHVf6iRsT0JoK/onZuT\nDWVmDN1ZGsBdAP4WBT52xrgWoYDHrYhX/nkA9qnqflU9C+B3ABYW
MI6ap6rvAjj+lYsXAlibfLwW\ng0+e3KWMrSao6mFV3ZV83Avgws7ShT52xrgKUUT4LwfwlyGfH0RtbfmtADaLyE4RaSt6MMOYlmyb\nDgBHAEwrcjDDcHduztNXdpaumceunB2vK41/8Pu6m1V1LoAfAliavL2tSTr4O1sttWtGtHNzXobZ\nWfoLRT525e54XWlFhP8QgJlDPv9WcllNUNVDyf+dAF5F7e0+fPTCJqnJ/50Fj+cLtbRz83A7S6MG\nHrta2vG6iPDvAHCNiMwSkUsA/ATAxgLG8TUi0pj8IQYi0gjgB6i93Yc3AliSfLwEwOsFjuVLamXn\n5rSdpVHwY1dzO16rau7/ANyJwb/4/y+AfyxiDCnjugrAfyf/3i96bABexuDbwHMY/NvIQwD+CsAW\nAB8D2Axgag2N7d8A7AHwHgaDNqOgsd2Mwbf07wHYnfy7s+jHzhhXIY8bz/AjCop/8CMKiuEnCorh\nJwqK4ScKiuEnCorhJwqK4ScKiuEnCur/ABj13FTNxtqQAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x11ed24550>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.image as mpimg\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "\n",
    "some_img = os.path.join('./mnist_train/9/', os.listdir('./mnist_train/9/')[0])\n",
    "\n",
    "img = mpimg.imread(some_img)\n",
    "print(img.shape)\n",
    "plt.imshow(img, cmap='binary');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: The JPEG format introduces a few artifacts that we can see in the image above. In this case, we use JPEG instead of PNG. Here, JPEG is used for demonstration purposes since that's still format many image datasets are stored in."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Saving images as HDF5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following wrapper function creates .h5 file containing training, testing, and validation datasets. It will group images together into larger integer arrays that are then saved as subgroups in the HDF5 file. For instance, the training images will be saved as `'train/images'` and the corresponding labels as `'train/labels'` subgroup."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import h5py\n",
    "import glob\n",
    "\n",
    "\n",
    "def images_to_h5(data_stempath='./mnist_',\n",
    "                 width=28, height=28, channels=1,\n",
    "                 shuffle=False, random_seed=None):\n",
    "    \n",
    "    with h5py.File('mnist_batches.h5', 'w') as h5f:\n",
    "    \n",
    "        for s in ['train', 'valid', 'test']:\n",
    "            img_paths = [p for p in glob.iglob('%s%s/**/*.jpg' % \n",
    "                                       (data_stempath, s), \n",
    "                                        recursive=True)]\n",
    "\n",
    "            dset1 = h5f.create_dataset('%s/images' % s, \n",
    "                                       shape=[len(img_paths), \n",
    "                                              width, height, channels], \n",
    "                                       compression=None,\n",
    "                                       dtype='uint8')\n",
    "            dset2 = h5f.create_dataset('%s/labels' % s, \n",
    "                                       shape=[len(img_paths)], \n",
    "                                       compression=None,\n",
    "                                       dtype='uint8')\n",
    "            dset3 = h5f.create_dataset('%s/file_ids' % s, \n",
    "                                       shape=[len(img_paths)], \n",
    "                                       compression=None,\n",
    "                                       dtype='S5')\n",
    "            \n",
    "            rand_indices = np.arange(len(img_paths))\n",
    "            \n",
    "            if shuffle:\n",
    "                rng = np.random.RandomState(random_seed)\n",
    "                rng.shuffle(rand_indices)\n",
    "\n",
    "            for idx, path in enumerate(img_paths):\n",
    "\n",
    "                rand_idx = rand_indices[idx]\n",
    "                label = int(os.path.basename(os.path.dirname(path)))\n",
    "                image = mpimg.imread(path)\n",
    "                dset1[rand_idx] = image.reshape(width, height, channels)\n",
    "                dset2[rand_idx] = label\n",
    "                dset3[rand_idx] = np.array([os.path.basename(path)], dtype='S6')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that we didn't specify any compression format. The reason is that non-compressed HDF5 datasets are much faster to read, which is an important factor for training deep learning systems. In this case, the dataset is about ~47 Mb in size. However, we are working with larger datasets, compressing the HDF5 dataset might be one easy way to deal with hardware storage limitations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "images_to_h5(shuffle=True, random_seed=123)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check that the archiving worked correctly, we will now load the training images and print the array shape. Note that we can now access each archive similar to a python dictionary. Here the `'data'` key contains the image data and the `'labels'` key stores an array containing the corresponding class labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(45000, 28, 28, 1)\n",
      "(45000,)\n",
      "(45000,)\n"
     ]
    }
   ],
   "source": [
    "with h5py.File('mnist_batches.h5', 'r') as h5f:\n",
    "    print(h5f['train/images'].shape)\n",
    "    print(h5f['train/labels'].shape)\n",
    "    print(h5f['train/file_ids'].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Class label: 8\n",
      "File ID: b'20829'\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEb5JREFUeJzt3W2MVGWWB/D/oXlrGhSwezvo4ALRmKBmwVRwyZANOgvx\nhQTng2bQDBjJAMlIFgNk1U1cPxhD1oWRBDOGWXDQzMBsMiBoyG6ArDEQnNBIjy+jjmB6BAL9Iq10\nQ0O/nf3QF9Ng3/M0davq3urz/yWkq+vU7Xq66H9XdZ37PI+oKojIn2FpD4CI0sHwEznF8BM5xfAT\nOcXwEznF8BM5xfATOcXwEznF8BM5NbyUd1ZdXa1Tpkwp5V26kOQsTREp6n1b9WHD7Oeenp4es15R\nUWHWLb29vWY9NLasamhoQEtLy6D+UxOFX0QeALARQAWA/1LVddbtp0yZgrq6uiR3SQPo7OyMrYV+\niIcPT/b7v6ury6xbY6uqqjKPPX/+vFm/4YYbzLrl4sWLZn306NFmPau/HHK53KBvm/d3ICIVAF4D\n8CCA6QAWicj0fL8eEZVWkl9fswAcV9WvVLUTwA4ACwszLCIqtiThvwXAyX6fn4quu4qILBOROhGp\na25uTnB3RFRIRf/DRVU3q2pOVXM1NTXFvjsiGqQk4T8NYHK/z38UXUdEZSBJ+I8AuF1EporISAA/\nA7CnMMMiomLLu8+jqt0i8jSA/0Vfq2+rqn5asJENIaGecqhXHupnjxw5MrbW3d1tHhtqp40aNSpR\nfcSIEWY9CauNCNiPy5gxYwo9nLKTqMmrqnsB7C3QWIiohLJ5pgIRFR3DT+QUw0/kFMNP5BTDT+QU\nw0/kVEnn83sVmjMfqofOE7CEpuwmmRYLhMfW0dERWwtN6U06Nuv8idC4Q2sJFHuqdCnwmZ/IKYaf\nyCmGn8gphp/IKYafyCmGn8ip7PcjhoBQKy80pTe0Qm4xp80mHZvVzgu10y5dumTWQ+26cePGxdZC\n06RD9cuXL5t1tvqIKLMYfiKnGH4ipxh+IqcYfiKnGH4ipxh+Iqey34x0INQrD/WcremloT79d999\nZ9YrKyvNemjpbqsfHjo2xOrjh7S2tpp1a9lvIDwduRzwmZ/IKYafyCmGn8gphp/IKYafyCmGn8gp\nhp/IqUR9fhFpANAGoAdAt6rmCjGooSY09ztpv9vaqjo013/8+PGJ7ru9vd2sjx07NrYW6rW//fbb\nZv3IkSNm/fDhw7G1lpYW89h58+aZ9a1bt5r1clCIk3zuU1X7kSSizOHLfiKnkoZfAewXkaMisqwQ\nAyKi0kj6sn+Oqp4Wkb8DsE9EPlfV9/vfIPqlsAwAbr311oR3R0SFkuiZX1VPRx+bAOwCMGuA22xW\n1Zyq5mpqapLcHREVUN7hF5EqERl35TKA+QA+KdTAiKi4krzsrwWwK1qWejiA36vq/xRkVERUdHmH\nX1W/AvAPBRzLkBXazjkpa85+aM+A0Nr5hw4dMuvvvvuuWf/ggw9iaw0NDeaxJ0+eNOvFtHPnTrP+\n6quvmvWk24uXAlt9RE4x/EROMfxETjH8RE4x/EROMfxETnHp7hIITatta2sz66FlpK0pwd3d3eax\nW7ZsMeuvv/66Wa+vrzfrVpsz9Lg8+OCDZn3WrB+cUHqVc+fOxdbOnj1rHrt7926zPmbMGLNeDvjM\nT+QUw0/kFMNP5BTDT+QUw0/kFMNP5BTDT+QU+/wl0NHRYdaTbDUNACdOnIitrV271jx2165dZn3i\nxIlmffbs2Wb98ccfj60tXrzYPDY0LTa0dPeOHTtiay+//LJ57L333mvWhw8v/+jwmZ/IKYafyCmG\nn8gphp/IKYafyCmGn8gphp/IqfJvVpaByspKs97Y2GjWv/jiC7O+fPny2Nrnn39uHhs6D2Dp0qVm\n/Y477jDrSTQ1NZn10PLa27dvj62tX7
/ePHb16tVmPbTkeUVFhVnPAj7zEznF8BM5xfATOcXwEznF\n8BM5xfATOcXwEzkV7POLyFYACwA0qepd0XUTAfwBwBQADQAeU9XW4g2zvLW0tJj12tpas75p0yaz\nfvr06djaK6+8Yh77xBNPmPVJkyaZ9VC/u729PbYWWrc/9H2vW7fOrE+dOjW2duHCBfNYay8EAOjt\n7TXrQ6XP/1sAD1xz3bMADqjq7QAORJ8TURkJhl9V3wdw7dYnCwFsiy5vA/BIgcdFREWW79/8tap6\nJrp8FoD9upWIMifxG36qqgA0ri4iy0SkTkTqmpubk94dERVIvuFvFJFJABB9jJ2BoaqbVTWnqrma\nmpo8746ICi3f8O8BsCS6vASAvaUpEWVOMPwish3AYQB3iMgpEVkKYB2AeSLyJYB/jj4nojIS7POr\n6qKY0k8KPJYhq7q62qxv2LDBrL/00ktmfd68ebG1NWvWmMcmdfHiRbO+b9++2Np7771nHvvaa6+Z\n9VAv/rnnnoutVVVVmcd2dXWZ9ZEjR5r1csAz/IicYviJnGL4iZxi+ImcYviJnGL4iZzi0t0lEGob\nhZbXHj9+fCGHc11C02q3bNli1uvr6ws5nKs8+eSTZv2pp57K+2uHpmGHpjqXAz7zEznF8BM5xfAT\nOcXwEznF8BM5xfATOcXwEznFPn8J9K10Fu/GG280699++61ZP3jwYGztrbfeMo89duyYWQ/18c+f\nP2/Wramz8+fPN49duXKlWb/vvvvMekdHR2wtNCU31Me/fPmyWQ9NN84CPvMTOcXwEznF8BM5xfAT\nOcXwEznF8BM5xfATOcU+fwmEesoPP/ywWd+/f79Zt+bML1682Dw2ZMKECWY91A/ftm1bbM1acnww\nzp27dv/Yq40ZMya2FtpCu7Oz06yXQx8/hM/8RE4x/EROMfxETjH8RE4x/EROMfxETjH8RE4F+/wi\nshXAAgBNqnpXdN2LAH4BoDm62fOqurdYgxzqZs+ebdZXrFhh1teuXRtba2try2tMV7S2tpr1jRs3\nmnWrl9/b22seO2yY/dw0ceJEs55EaGxDwWCe+X8L4IEBrv+Vqs6I/jH4RGUmGH5VfR+AfSoVEZWd\nJH/zrxSRj0Rkq4jY54ASUebkG/5fA5gGYAaAMwDWx91QRJaJSJ2I1DU3N8fdjIhKLK/wq2qjqvao\nai+A3wCYZdx2s6rmVDVXU1OT7ziJqMDyCr+I9J/K9VMAnxRmOERUKoNp9W0HMBdAtYicAvDvAOaK\nyAwACqABwPIijpGIiiAYflVdNMDV9mLudJXQuv3Dh9v/DcePHzfrVi8/9LW7u7vNesj69bFv9wAA\n5syZE1u75557zGN7enrMemhOfhKhNRhC/6ciUsjhFAXP8CNyiuEncorhJ3KK4SdyiuEncorhJ3KK\nS3eXQKjtE9pGe9OmTWb9zjvvjK2Ftrk+dOiQWX/nnXfM+tdff23Wn3766dja7t27zWMrKyvN+tix\nY826JdSqC00nHgqG/ndIRANi+ImcYviJnGL4iZxi+ImcYviJnGL4iZxin78EWlpazPrRo0fNem1t\nrVnfs2dPbG3atGnmscuX20sxhKbsrlmzxqwfPnw4tnbq1Cnz2Lvvvtush1i9/NDS3KFzM4bCeQDl\n/x0QUV4YfiKnGH4ipxh+IqcYfiKnGH4ipxh+IqfY5y+B6upqs753r73J8fTp0836zTfffN1juqKz\ns9OsP/PMM2Z91KhReR8fWktg5syZZj3EWpZ8xIgReR8LhJcVD339LOAzP5FTDD+RUww/kVMMP5FT\nDD+RUww/kVMMP5FTwT6/iEwG8CaAWgAKYLOqbhSRiQD+AGAKgAYAj6lqa/GGWr6++eYbsx6a1x7a\nynr06NGxta6uLvPY0FbUoX730qVLzbq1HsAbb7xhHrtixQqzHtp+3OrFh/rwofn65bAFd8hgnvm7\nAa
xW1ekA/hHAL0VkOoBnARxQ1dsBHIg+J6IyEQy/qp5R1Q+jy20APgNwC4CFALZFN9sG4JFiDZKI\nCu+6/uYXkSkAZgL4E4BaVT0Tlc6i788CIioTgw6/iIwF8EcAq1T1fP+a9i2WNuCCaSKyTETqRKSu\nubk50WCJqHAGFX4RGYG+4P9OVXdGVzeKyKSoPglA00DHqupmVc2paq6mpqYQYyaiAgiGX/re1twC\n4DNV3dCvtAfAkujyEgD2lqtElCmDmdL7YwA/B/CxiNRH1z0PYB2A/xaRpQD+BuCx4gyx/N10001m\nPdSOe/TRR826NS031MoLtSHHjx9v1ltb7e7uhQsXYmuhNmJoee2QUCvQEmrlDYVWX/DRUdWDAOK+\n058UdjhEVCo8w4/IKYafyCmGn8gphp/IKYafyCmGn8gpLt2dAdZW0gBw7Ngxs37//ffH1k6cOGEe\nu2/fPrPe2Nho1q0+PgBcunQptrZq1Srz2NA5CiFWnz90DsFQ2II7ZOh/h0Q0IIafyCmGn8gphp/I\nKYafyCmGn8gphp/IKQn1mAspl8tpXV1dye4vKzo6Osx6VVVVonp7e3tsLbR6UmhptdC89dDPj7Ue\nQOi+KyoqzHpoe3Fr+3Dr/APAXg4dyO55ArlcDnV1dYNabIDP/EROMfxETjH8RE4x/EROMfxETjH8\nRE4x/EROcT5/CYTmpb/wwgtmfePGjXnfd1tbm1m/7bbbzLq1VgAALFiwwKzPnTs3thbq04e20bb6\n+CGhPn7IUFi3n8/8RE4x/EROMfxETjH8RE4x/EROMfxETjH8RE4F5/OLyGQAbwKoBaAANqvqRhF5\nEcAvAFyZlP28qu61vpbX+fxJhdYDsOaOJ9mjHgjPqQ+xevmhOe+h+06z1z4U5vMP5iejG8BqVf1Q\nRMYBOCoiV3Z6+JWq/me+AyWi9ATDr6pnAJyJLreJyGcAbin2wIiouK7rtYmITAEwE8CfoqtWishH\nIrJVRCbEHLNMROpEpC60bBMRlc6gwy8iYwH8EcAqVT0P4NcApgGYgb5XBusHOk5VN6tqTlVzofXk\niKh0BhV+ERmBvuD/TlV3AoCqNqpqj6r2AvgNgFnFGyYRFVow/NL3luoWAJ+p6oZ+10/qd7OfAvik\n8MMjomIZzLv9PwbwcwAfi0h9dN3zABaJyAz0tf8aACwvygiHgNbWVrM+YcKAb5d8r7KyMu/77unp\nMetJW3mhr590m+2sGgpTegfzbv9BAAN9p2ZPn4iyjWf4ETnF8BM5xfATOcXwEznF8BM5xfATOcWl\nu0sg1McPTatO0qsP9fG7u7vNeqifnfQ8gXI1FPr8fOYncorhJ3KK4SdyiuEncorhJ3KK4SdyiuEn\nciq4dHdB70ykGcDf+l1VDaClZAO4PlkdW1bHBXBs+Srk2P5eVQe1Xl5Jw/+DOxepU9VcagMwZHVs\nWR0XwLHlK62x8WU/kVMMP5FTaYd/c8r3b8nq2LI6LoBjy1cqY0v1b34iSk/az/xElJJUwi8iD4jI\nFyJyXESeTWMMcUSkQUQ+FpF6EUl1S+FoG7QmEfmk33UTRWSfiHwZfbTnC5d2bC+KyOnosasXkYdS\nGttkEfk/EfmLiHwqIv8SXZ/qY2eMK5XHreQv+0WkAsBfAcwDcArAEQCLVPUvJR1IDBFpAJBT1dR7\nwiLyTwDaAbypqndF1/0HgHOqui76xTlBVf81I2N7EUB72js3RxvKTOq/szSARwA8iRQfO2NcjyGF\nxy2NZ/5ZAI6r6leq2glgB4CFKYwj81T1fQDnrrl6IYBt0eVt6PvhKbmYsWWCqp5R1Q+jy20Aruws\nnepjZ4wrFWmE/xYAJ/t9fgrZ2vJbAewXkaMisiztwQygNto2HQDOAqhNczADCO7cXErX7Cydmccu\nnx2vC41v+P3QHFWdAeBBAL+MXt5mkvb9zZalds2gdm4ulQF2lv5e
mo9dvjteF1oa4T8NYHK/z38U\nXZcJqno6+tgEYBeyt/tw45VNUqOPTSmP53tZ2rl5oJ2lkYHHLks7XqcR/iMAbheRqSIyEsDPAOxJ\nYRw/ICJV0RsxEJEqAPORvd2H9wBYEl1eAmB3imO5SlZ2bo7bWRopP3aZ2/FaVUv+D8BD6HvH/wSA\nf0tjDDHjmgbgz9G/T9MeG4Dt6HsZ2IW+90aWArgJwAEAXwLYD2Bihsb2FoCPAXyEvqBNSmlsc9D3\nkv4jAPXRv4fSfuyMcaXyuPEMPyKn+IYfkVMMP5FTDD+RUww/kVMMP5FTDD+RUww/kVMMP5FT/w9n\nq6jefCrr/gAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x11f6c4c88>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "with h5py.File('mnist_batches.h5', 'r') as h5f:\n",
    "\n",
    "    plt.imshow(h5f['train/images'][0][:, :, -1], cmap='binary');\n",
    "    print('Class label:', h5f['train/labels'][0])\n",
    "    print('File ID:', h5f['train/file_ids'][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Loading Minibatches"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell implements a class for iterating over the MNIST images, based on the .h5 file, conveniently. \n",
    "Via the `normalize` parameter we additionally scale the image pixels to [0, 1] range, which typically helps with gradient-based optimization in practice.\n",
    "\n",
    "The key functions (here: generators) are\n",
    "\n",
    "- load_train_epoch\n",
    "- load_valid_epoch\n",
    "- load_test_epoch\n",
    "\n",
    "These let us iterate over small chunks (determined via `minibatch_size`) and yield minibatches via memory-efficient Python generators. Via the two shuffle parameters, we can further control if the images within each batch to be shuffled. By setting `onehot=True`, the labels are converted into a onehot representation for convenience."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class BatchLoader():\n",
    "    def __init__(self, minibatches_path, \n",
    "                 normalize=True):\n",
    "        \n",
    "        self.minibatches_path = minibatches_path\n",
    "        self.normalize = normalize\n",
    "        self.num_train = 45000\n",
    "        self.num_valid = 5000\n",
    "        self.num_test = 10000\n",
    "        self.n_classes = 10\n",
    "\n",
    "\n",
    "    def load_train_epoch(self, batch_size=50, onehot=False,\n",
    "                         shuffle_batch=False, prefetch_batches=1, seed=None):\n",
    "        for batch_x, batch_y in self._load_epoch(which='train',\n",
    "                                                 batch_size=batch_size,\n",
    "                                                 onehot=onehot,\n",
    "                                                 shuffle_batch=shuffle_batch,\n",
    "                                                 prefetch_batches=prefetch_batches, \n",
    "                                                 seed=seed):\n",
    "            yield batch_x, batch_y\n",
    "\n",
    "    def load_test_epoch(self, batch_size=50, onehot=False,\n",
    "                        shuffle_batch=False, prefetch_batches=1, seed=None):\n",
    "        for batch_x, batch_y in self._load_epoch(which='test',\n",
    "                                                 batch_size=batch_size,\n",
    "                                                 onehot=onehot,\n",
    "                                                 shuffle_batch=shuffle_batch,\n",
    "                                                 prefetch_batches=prefetch_batches,\n",
    "                                                 seed=seed):\n",
    "            yield batch_x, batch_y\n",
    "            \n",
    "    def load_validation_epoch(self, batch_size=50, onehot=False,\n",
    "                              shuffle_batch=False, prefetch_batches=1, seed=None):\n",
    "        for batch_x, batch_y in self._load_epoch(which='valid',\n",
    "                                                 batch_size=batch_size,\n",
    "                                                 onehot=onehot,\n",
    "                                                 shuffle_batch=shuffle_batch,\n",
    "                                                 prefetch_batches=prefetch_batches, \n",
    "                                                 seed=seed):\n",
    "            yield batch_x, batch_y\n",
    "\n",
    "    def _load_epoch(self, which='train', batch_size=50, onehot=False,\n",
    "                    shuffle_batch=False, prefetch_batches=1, seed=None):\n",
    "        \n",
    "        prefetch_size = prefetch_batches * batch_size\n",
    "        \n",
    "        if shuffle_batch:\n",
    "            rgen = np.random.RandomState(seed)\n",
    "\n",
    "        with h5py.File(self.minibatches_path, 'r') as h5f:\n",
    "            indices = np.arange(h5f['%s/images' % which].shape[0])\n",
    "            \n",
    "            for start_idx in range(0, indices.shape[0] - prefetch_size + 1,\n",
    "                                   prefetch_size):           \n",
    "            \n",
    "\n",
    "                x_batch = h5f['%s/images' % which][start_idx:start_idx + prefetch_size]\n",
    "                x_batch = x_batch.astype(np.float32)\n",
    "                y_batch = h5f['%s/labels' % which][start_idx:start_idx + prefetch_size]\n",
    "\n",
    "                if onehot:\n",
    "                    y_batch = (np.arange(self.n_classes) == \n",
    "                               y_batch[:, None]).astype(np.uint8)\n",
    "\n",
    "                if self.normalize:\n",
    "                    # normalize to [0, 1] range\n",
    "                    x_batch = x_batch.astype(np.float32) / 255.\n",
    "\n",
    "                if shuffle_batch:\n",
    "                    rand_indices = np.arange(prefetch_size)\n",
    "                    rgen.shuffle(rand_indices)\n",
    "                    x_batch = x_batch[rand_indices]\n",
    "                    y_batch = y_batch[rand_indices]\n",
    "\n",
    "                for batch_idx in range(0, x_batch.shape[0] - batch_size + 1,\n",
    "                                       batch_size):\n",
    "                    \n",
    "                    yield (x_batch[batch_idx:batch_idx + batch_size], \n",
    "                           y_batch[batch_idx:batch_idx + batch_size])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following for loop will iterate over the 45,000 training examples in our MNIST training set, yielding 50 images and labels at a time (note that we previously set aside 5000 training example as our validation datast)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(50, 28, 28, 1)\n",
      "(50, 10)\n"
     ]
    }
   ],
   "source": [
    "batch_loader = BatchLoader(minibatches_path='./mnist_batches.h5', \n",
    "                           normalize=True)\n",
    "\n",
    "for batch_x, batch_y in batch_loader.load_train_epoch(batch_size=50, onehot=True):\n",
    "    print(batch_x.shape)\n",
    "    print(batch_y.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "One training epoch contains 45000 images\n"
     ]
    }
   ],
   "source": [
    "cnt = 0\n",
    "for batch_x, batch_y in batch_loader.load_train_epoch(\n",
    "        batch_size=100, onehot=True):\n",
    "    cnt += batch_x.shape[0]\n",
    "    \n",
    "print('One training epoch contains %d images' % cnt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "323 ms ± 9.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "def one_epoch():\n",
    "    for batch_x, batch_y in batch_loader.load_train_epoch(\n",
    "            batch_size=100, onehot=True):\n",
    "        pass\n",
    "    \n",
    "% timeit one_epoch()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see from the benchmark above, an iteration over one training epoch (45k images) is relatively fast."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly, we could iterate over validation and test data via \n",
    "\n",
    "- batch_loader.load_validation_epoch\n",
    "- batch_loader.load_test_epoch\n",
    "\n",
    "Note that increasing the `batch_size` can substantially improve the computationally efficiency loading an epoch, since it would lower the number of iterations. Further, we used two nested for loops in `_load_epoch`, where the inner one yields the actual batches. The purpose of the outer loop in this function is to prefetch multiple batches for shuffling -- otherwise, the shuffling won't have any effect on the gradients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "172 ms ± 11.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "def one_epoch():\n",
    "    for batch_x, batch_y in batch_loader.load_train_epoch(\n",
    "            batch_size=100, shuffle_batch=True, prefetch_batches=4, \n",
    "            seed=123, onehot=True):\n",
    "        pass\n",
    "    \n",
    "%timeit one_epoch()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Also, as we can see from the benchmark, prefetching multiple batches from the HDF5 database can speed up the loading of an epoch (here, from about 323 ms to about 172 ms). Note that this may not always be practicable (for example, when we are working with high-resolution images) due to memory constraints."
   ]
  },
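  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a rough feeling for the memory constraint mentioned above, we can estimate the size of one prefetched chunk. The helper `prefetch_megabytes` is a hypothetical back-of-the-envelope sketch; it assumes the images are held in memory as 32-bit floats, and the 1024x1024 RGB resolution is only an illustrative example.\n",
    "\n",
    "```python\n",
    "def prefetch_megabytes(batch_size, prefetch_batches, height, width,\n",
    "                       channels, bytes_per_value=4):\n",
    "    # Size of one prefetched chunk of images in MB (10**6 bytes),\n",
    "    # assuming bytes_per_value bytes per pixel value (4 for float32)\n",
    "    n_values = batch_size * prefetch_batches * height * width * channels\n",
    "    return n_values * bytes_per_value / 1e6\n",
    "\n",
    "\n",
    "# 4 prefetched batches of 100 grayscale 28x28 images: about 1.25 MB\n",
    "mnist_mb = prefetch_megabytes(100, 4, 28, 28, 1)\n",
    "\n",
    "# The same setting with 1024x1024 RGB images: about 5033 MB (roughly 5 GB)\n",
    "highres_mb = prefetch_megabytes(100, 4, 1024, 1024, 3)\n",
    "```\n",
    "\n",
    "The estimate grows linearly with `prefetch_batches`, so for large images a smaller prefetch size (or a smaller `batch_size`) keeps the chunk within the available memory."
   ]
  },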
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Training a Model using TensorFlow's `feed_dict`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code demonstrates how we can feed our minibatches into a TensorFlow graph using a TensorFlow session's `feed_dict`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multilayer Perceptron Graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "learning_rate = 0.1\n",
    "training_epochs = 15\n",
    "batch_size = 100\n",
    "\n",
    "# Architecture\n",
    "n_hidden_1 = 128\n",
    "n_hidden_2 = 256\n",
    "height, width = 28, 28\n",
    "n_classes = 10\n",
    "\n",
    "\n",
    "##########################\n",
    "### GRAPH DEFINITION\n",
    "##########################\n",
    "\n",
    "g = tf.Graph()\n",
    "with g.as_default():\n",
    "    \n",
    "    tf.set_random_seed(123)\n",
    "\n",
    "    # Input data\n",
    "    tf_x = tf.placeholder(tf.float32, [None, height, width, 1], name='features')\n",
    "    tf_x_flat = tf.reshape(tf_x, shape=[-1, height*width])\n",
    "    tf_y = tf.placeholder(tf.int32, [None, n_classes], name='targets')\n",
    "\n",
    "    # Model parameters\n",
    "    weights = {\n",
    "        'h1': tf.Variable(tf.truncated_normal([width*height, n_hidden_1], stddev=0.1)),\n",
    "        'h2': tf.Variable(tf.truncated_normal([n_hidden_1, n_hidden_2], stddev=0.1)),\n",
    "        'out': tf.Variable(tf.truncated_normal([n_hidden_2, n_classes], stddev=0.1))\n",
    "    }\n",
    "    biases = {\n",
    "        'b1': tf.Variable(tf.zeros([n_hidden_1])),\n",
    "        'b2': tf.Variable(tf.zeros([n_hidden_2])),\n",
    "        'out': tf.Variable(tf.zeros([n_classes]))\n",
    "    }\n",
    "\n",
    "    # Multilayer perceptron\n",
    "    layer_1 = tf.add(tf.matmul(tf_x_flat, weights['h1']), biases['b1'])\n",
    "    layer_1 = tf.nn.relu(layer_1)\n",
    "    layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])\n",
    "    layer_2 = tf.nn.relu(layer_2)\n",
    "    out_layer = tf.matmul(layer_2, weights['out']) + biases['out']\n",
    "\n",
    "    # Loss and optimizer\n",
    "    loss = tf.nn.softmax_cross_entropy_with_logits(logits=out_layer, labels=tf_y)\n",
    "    cost = tf.reduce_mean(loss, name='cost')\n",
    "    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\n",
    "    train = optimizer.minimize(cost, name='train')\n",
    "\n",
    "    # Prediction\n",
    "    correct_prediction = tf.equal(tf.argmax(tf_y, 1), tf.argmax(out_layer, 1))\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training the Neural Network with Minibatches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001 | AvgCost: 0.462 | MbTrain/Valid ACC: 0.950/0.901\n",
      "Epoch: 002 | AvgCost: 0.218 | MbTrain/Valid ACC: 0.980/0.935\n",
      "Epoch: 003 | AvgCost: 0.164 | MbTrain/Valid ACC: 0.980/0.950\n",
      "Epoch: 004 | AvgCost: 0.132 | MbTrain/Valid ACC: 1.000/0.958\n",
      "Epoch: 005 | AvgCost: 0.111 | MbTrain/Valid ACC: 1.000/0.958\n",
      "Epoch: 006 | AvgCost: 0.096 | MbTrain/Valid ACC: 0.980/0.964\n",
      "Epoch: 007 | AvgCost: 0.083 | MbTrain/Valid ACC: 0.990/0.965\n",
      "Epoch: 008 | AvgCost: 0.073 | MbTrain/Valid ACC: 0.990/0.966\n",
      "Epoch: 009 | AvgCost: 0.064 | MbTrain/Valid ACC: 1.000/0.966\n",
      "Epoch: 010 | AvgCost: 0.057 | MbTrain/Valid ACC: 0.980/0.967\n",
      "Epoch: 011 | AvgCost: 0.050 | MbTrain/Valid ACC: 1.000/0.969\n",
      "Epoch: 012 | AvgCost: 0.045 | MbTrain/Valid ACC: 1.000/0.969\n",
      "Epoch: 013 | AvgCost: 0.040 | MbTrain/Valid ACC: 0.990/0.971\n",
      "Epoch: 014 | AvgCost: 0.036 | MbTrain/Valid ACC: 1.000/0.969\n",
      "Epoch: 015 | AvgCost: 0.032 | MbTrain/Valid ACC: 1.000/0.970\n",
      "Test ACC: 0.974\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### TRAINING & EVALUATION\n",
    "##########################\n",
    "\n",
    "batch_loader = BatchLoader(minibatches_path='./mnist_batches.h5', \n",
    "                           normalize=True)\n",
    "\n",
    "# preload small validation set\n",
    "# by unpacking the generator\n",
    "[valid_data] = batch_loader.load_validation_epoch(batch_size=5000, \n",
    "                                                   onehot=True)\n",
    "valid_x, valid_y = valid_data[0], valid_data[1]\n",
    "del valid_data\n",
    "\n",
    "with tf.Session(graph=g) as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "\n",
    "    for epoch in range(training_epochs):\n",
    "        avg_cost = 0.\n",
    "\n",
    "        n_batches = 0\n",
    "        for batch_x, batch_y in batch_loader.load_train_epoch(batch_size=batch_size, \n",
    "                                                              onehot=True,\n",
    "                                                              shuffle_batch=True,\n",
    "                                                              prefetch_batches=10,\n",
    "                                                              seed=epoch):\n",
    "            n_batches += 1\n",
    "            _, c = sess.run(['train', 'cost:0'], feed_dict={'features:0': batch_x,\n",
    "                                                            'targets:0': batch_y.astype(np.int32)})\n",
    "            avg_cost += c\n",
    "        \n",
    "        train_acc = sess.run('accuracy:0', feed_dict={'features:0': batch_x,\n",
    "                                                      'targets:0': batch_y})\n",
    "        \n",
    "        valid_acc = sess.run('accuracy:0', feed_dict={'features:0': valid_x,\n",
    "                                                      'targets:0': valid_y})  \n",
    "        \n",
    "        print(\"Epoch: %03d | AvgCost: %.3f\" % (epoch + 1, avg_cost / n_batches), end=\"\")\n",
    "        print(\" | MbTrain/Valid ACC: %.3f/%.3f\" % (train_acc, valid_acc))\n",
    "        \n",
    "        \n",
    "    # imagine test set is too large to fit into memory:\n",
    "    test_acc, cnt = 0., 0\n",
    "    for test_x, test_y in batch_loader.load_test_epoch(batch_size=100, \n",
    "                                                       onehot=True):   \n",
    "        cnt += 1\n",
    "        acc = sess.run(accuracy, feed_dict={'features:0': test_x,\n",
    "                                            'targets:0': test_y})\n",
    "        test_acc += acc\n",
    "    print('Test ACC: %.3f' % (test_acc / cnt))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
